/** Copyright 2020-2023 Alibaba Group Holding Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#ifndef MODULES_BASIC_DS_HASHMAP_VINEYARD_MOD_
#define MODULES_BASIC_DS_HASHMAP_VINEYARD_MOD_

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "flat_hash_map/flat_hash_map.hpp"

#include "cityhash/cityhash.hpp"
#include "wyhash/wyhash.hpp"

#include "basic/ds/array.vineyard.h"
#include "basic/ds/arrow.vineyard.h"  // IWYU pragma: keep
#include "basic/ds/arrow_utils.h"
#include "basic/utils.h"
#include "client/ds/blob.h"
#include "client/ds/i_object.h"
#include "common/util/arrow.h"
#include "grape/vertex_map/idxers/pthash_idxer.h"

namespace vineyard {

#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wattributes"
#endif

/**
 * @brief Maps a hash value to a slot index by taking it modulo a stored
 * modulus.
 *
 * The modulus starts at 1 (every hash maps to index 0) and is replaced via
 * `set_prime()`. `set_prime()` stores whatever it is given, so callers must
 * not pass 0: `index_for_hash()` would then divide by zero.
 */
struct prime_hash_policy {
  prime_hash_policy() = default;
  prime_hash_policy(const prime_hash_policy& rhs) = default;
  prime_hash_policy& operator=(const prime_hash_policy& rhs) = default;

  /// Return the slot index for `hash`, i.e. `hash` modulo the stored modulus.
  size_t index_for_hash(size_t hash) const {
    const size_t slot = hash % current_prime_;
    return slot;
  }

  /// Replace the stored modulus.
  void set_prime(size_t prime) { current_prime_ = prime; }

 private:
  // Modulus applied by index_for_hash(); 1 until set_prime() is called.
  size_t current_prime_ = 1;
};

// Hash functors that expose `ska::prime_number_hash_policy` as their
// `hash_policy`, for use with flat_hash_map.

/**
 * @brief `std::hash<T>` extended with a nested `hash_policy` type.
 *
 * The hashing itself is inherited unchanged from `std::hash<T>`; the only
 * addition is the `hash_policy` alias naming `ska::prime_number_hash_policy`.
 * NOTE(review): flat_hash_map presumably reads this nested type to choose how
 * hashes are mapped to buckets -- confirm against flat_hash_map.hpp.
 *
 * @tparam T The key type being hashed.
 */
template <typename T>
struct prime_number_hash : std::hash<T> {
  using hash_policy = ska::prime_number_hash_policy;
};

/**
 * @brief Type-name specialization for `prime_number_hash<U>`.
 *
 * Reports the name of the wrapped hasher, `std::hash<U>`, rather than a name
 * for the wrapper itself.
 */
template <typename U>
struct typename_t<prime_number_hash<U>> {
  inline static const std::string name() {
    using wrapped_hasher = std::hash<U>;
    return type_name<wrapped_hasher>();
  }
};

/**
 * @brief `wy::hash<T>` extended with a nested `hash_policy` type.
 *
 * The hashing itself is inherited unchanged from `wy::hash<T>`; the only
 * addition is the `hash_policy` alias naming `ska::prime_number_hash_policy`.
 * NOTE(review): flat_hash_map presumably reads this nested type to choose how
 * hashes are mapped to buckets -- confirm against flat_hash_map.hpp.
 *
 * @tparam T The key type being hashed.
 */
template <typename T>
struct prime_number_hash_wy : wy::hash<T> {
  using hash_policy = ska::prime_number_hash_policy;
};

/**
 * @brief Type-name specialization for `prime_number_hash_wy<U>`.
 *
 * Reports the name of the wrapped hasher, `wy::hash<U>`, rather than a name
 * for the wrapper itself.
 */
template <typename U>
struct typename_t<prime_number_hash_wy<U>> {
  inline static const std::string name() {
    using wrapped_hasher = wy::hash<U>;
    return type_name<wrapped_hasher>();
  }
};

/**
 * @brief `city::hash<T>` extended with a nested `hash_policy` type.
 *
 * The hashing itself is inherited unchanged from `city::hash<T>`; the only
 * addition is the `hash_policy` alias naming `ska::prime_number_hash_policy`.
 * NOTE(review): flat_hash_map presumably reads this nested type to choose how
 * hashes are mapped to buckets -- confirm against flat_hash_map.hpp.
 *
 * @tparam T The key type being hashed.
 */
template <typename T>
struct prime_number_hash_city : city::hash<T> {
  using hash_policy = ska::prime_number_hash_policy;
};

/**
 * @brief Type-name specialization for `prime_number_hash_city<U>`.
 *
 * Reports the name of the wrapped hasher, `city::hash<U>`, rather than a name
 * for the wrapper itself.
 */
template <typename U>
struct typename_t<prime_number_hash_city<U>> {
  inline static const std::string name() {
    using wrapped_hasher = city::hash<U>;
    return type_name<wrapped_hasher>();
  }
};

// Forward declaration: `Hashmap<K, V, H, E>` names this builder as a friend,
// so it must be declared before the class definition below. The definition is
// not part of this file.
template <typename K, typename V, typename H, typename E>
class HashmapBaseBuilder;

// Forward declaration: `PerfectHashmap<K, V>` names this builder as a friend,
// so it must be declared before the class definition below. The definition is
// not part of this file.
template <typename K, typename V>
class PerfectHashmapBaseBuilder;

/**
 * @brief The hash map in vineyard.
 *
 * The table's slots are `ska::detailv3::sherwood_v3_entry<std::pair<K, V>>`
 * records stored in a vineyard `Array`. This class only exposes lookup and
 * iteration; it has no insertion or erasure operations.
 *
 * @tparam K The type for the key.
 * @tparam V The type for the value.
 * @tparam H The hash function for the key, defaults to
 *           `prime_number_hash_wy<K>`.
 * @tparam E The compare function for the key, defaults to `std::equal_to<K>`.
 */
template <typename K, typename V, typename H = prime_number_hash_wy<K>,
          typename E = std::equal_to<K>>
class [[vineyard]] Hashmap : public Registered<Hashmap<K, V, H, E>>,
                             public H,
                             public E {
 public:
  using KeyHash = H;
  using KeyEqual = E;

  using T = std::pair<K, V>;

  using Entry = ska::detailv3::sherwood_v3_entry<T>;
  using EntryPointer = const Entry*;

  using Hasher = ska::detailv3::KeyOrValueHasher<K, std::pair<K, V>, H>;
  using Equal = ska::detailv3::KeyOrValueEquality<K, std::pair<K, V>, E>;

  /**
   * @brief Finish initialization after the shared members are populated.
   *
   * Sets the modulus of the hash policy to the number of slots
   * (`num_slots_minus_one_ + 1`). When a mapped key data buffer is present,
   * also computes the offset between the recorded buffer address
   * (`data_buffer_`) and the address of the mapped blob.
   */
  void PostConstruct(const ObjectMeta& meta) override {
    hash_policy_.set_prime(num_slots_minus_one_ + 1);

    if (data_buffer_mapped_ != nullptr) {
      // value in original string_view: k1 = data_buffer + offset
      // value in mmap-ed string_view: k2 = data_buffer_mapped_ + offset
      //
      // -> k2 = k1 + diff
      diff_ = reinterpret_cast<uintptr_t>(data_buffer_mapped_->data()) -
              data_buffer_;
    }
  }

  using value_type = T;
  using size_type = size_t;
  using difference_type = std::ptrdiff_t;
  using hasher = H;
  using key_equal = E;
  using reference = value_type&;
  using const_reference = const value_type&;
  using pointer = value_type*;
  // Pointer-to-const, as the standard container requirements define it.
  using const_pointer = const value_type*;

  using flat_hash_table_type = ska::detailv3::sherwood_v3_table<
      T, K, H, Hasher, E, Equal, std::allocator<T>,
      typename std::allocator_traits<std::allocator<T>>::template rebind_alloc<
          ska::detailv3::sherwood_v3_entry<T>>>;

  /**
   * @brief The iterator to iterate key-value mappings in the HashMap.
   *
   * A thin wrapper around a pointer into the entry array. Dereferencing
   * yields a const reference to the stored key-value pair.
   */
  struct iterator {
    iterator() = default;
    explicit iterator(EntryPointer current) : current(current) {}
    EntryPointer current = EntryPointer();

    friend bool operator==(const iterator& lhs, const iterator& rhs) {
      return lhs.current == rhs.current;
    }

    friend bool operator!=(const iterator& lhs, const iterator& rhs) {
      return lhs.current != rhs.current;
    }

    // Step to the next entry that is not empty. The loop has no explicit
    // bound, so it relies on a non-empty entry existing after the current
    // one (NOTE(review): presumably the end sentinel written by the builder
    // -- confirm).
    iterator& operator++() {
      do {
        ++current;
      } while (current->is_empty());
      return *this;
    }

    iterator operator++(int) {
      iterator copy(*this);
      ++*this;
      return copy;
    }

    const value_type& operator*() const { return current->value; }

    const value_type* operator->() const {
      return std::addressof(current->value);
    }
  };

  /**
   * @brief The beginning iterator.
   *
   * Scans forward from the first entry and returns the first one for which
   * `has_value()` is true. The scan has no explicit bound, so it relies on
   * such an entry existing (NOTE(review): presumably the end sentinel when
   * the map is empty -- confirm).
   */
  iterator begin() const {
    for (EntryPointer it = entries_.data();; ++it) {
      if (it->has_value()) {
        return iterator(it);
      }
    }
  }

  /**
   * @brief The ending iterator.
   *
   * Points at entry index `num_slots_minus_one_ + max_lookups_`.
   */
  iterator end() const {
    return iterator(entries_.data() + static_cast<ptrdiff_t>(
                                          num_slots_minus_one_ + max_lookups_));
  }

  /**
   * @brief Find the iterator by key.
   *
   * Probing starts at the slot selected by the hash policy and walks forward
   * while the stored `distance_from_desired` of the visited entry is at least
   * the number of steps taken so far.
   *
   * @return An iterator to the matching entry, or `end()` if there is none.
   */
  iterator find(const K& key) {
    size_t index = hash_policy_.index_for_hash(hash_object(key));
    EntryPointer it = entries_.data() + static_cast<ptrdiff_t>(index);
    for (int8_t distance = 0; it->distance_from_desired >= distance;
         ++distance, ++it) {
      if (compares_equal(key, it->value.first)) {
        return iterator(it);
      }
    }
    return end();
  }

  /**
   * @brief Return the const iterator by key.
   *
   * Forwards to the non-const overload, which does not modify the map.
   */
  const iterator find(const K& key) const {
    return const_cast<Hashmap<K, V, H, E>*>(this)->find(key);
  }

  /**
   * @brief Return the number of occurrences of the key (0 or 1).
   *
   */
  size_t count(const K& key) const { return find(key) == end() ? 0 : 1; }

  /**
   * @brief Return the size of the HashMap, i.e., the number of elements stored
   * in the HashMap.
   *
   */
  size_t size() const { return num_elements_; }

  /**
   * @brief Return the max size of the HashMap, i.e., the number of allocated
   * cells for elements stored in the HashMap.
   *
   * Reports 0 when `num_slots_minus_one_` is 0.
   */
  size_t bucket_count() const {
    return num_slots_minus_one_ ? num_slots_minus_one_ + 1 : 0;
  }

  /**
   * @brief Return the load factor of the HashMap.
   *
   * This is `size() / bucket_count()`, or 0 when the bucket count is 0.
   */
  float load_factor() const {
    size_t bucket_count = num_slots_minus_one_ ? num_slots_minus_one_ + 1 : 0;
    if (bucket_count) {
      return static_cast<float>(num_elements_) / bucket_count;
    } else {
      return 0.0f;
    }
  }

  /**
   * @brief Check whether the HashMap is empty.
   *
   */
  bool empty() const { return num_elements_ == 0; }

  /**
   * @brief Get the value by key.
   * Here the existence of the key is checked.
   *
   * @throws std::out_of_range if the key is not in the map.
   */
  const V& at(const K& key) const {
    auto found = this->find(key);
    if (found == this->end()) {
      throw std::out_of_range("Argument passed to at() was not in the map.");
    }
    return found->second;
  }

 private:
  [[shared]] size_t num_slots_minus_one_;
  [[shared]] int8_t max_lookups_;
  [[shared]] size_t num_elements_;
  [[shared]] Array<Entry> entries_;

  // Modulus is set in PostConstruct() from num_slots_minus_one_.
  prime_hash_policy hash_policy_;

  // Used only when K is arrow_string_view (see the compares_equal overload
  // below): data_buffer_ is the buffer address recorded with the keys,
  // data_buffer_mapped_ is the blob holding that buffer, and diff_ is the
  // offset between the two computed in PostConstruct().
  [[shared]] uintptr_t data_buffer_ = 0;
  [[shared]] std::shared_ptr<Blob> data_buffer_mapped_;
  ptrdiff_t diff_ = 0;

  friend class Client;
  friend class HashmapBaseBuilder<K, V, H, E>;

  // Hash `key` with the inherited hasher H.
  size_t hash_object(const K& key) const {
    return static_cast<const H&>(*this)(key);
  }

  // Generic key comparison: delegate to the inherited comparator E.
  template <typename KT,
            typename std::enable_if<!std::is_same<KT, arrow_string_view>::value,
                                    bool>::type* = nullptr>
  bool compares_equal(const KT& lhs, const KT& rhs) const {
    return static_cast<const E&>(*this)(lhs, rhs);
  }

  // String-view key comparison: `rhs` is the stored key, whose data pointer
  // is shifted by diff_ before it is compared against the lookup key `lhs`.
  template <typename KT, typename std::enable_if<std::is_same<
                             KT, arrow_string_view>::value>::type* = nullptr>
  bool compares_equal(const KT& lhs, const KT& rhs) const {
    return static_cast<const E&>(*this)(
        lhs, arrow_string_view{rhs.data() + diff_, rhs.size()});
  }
};

using grape::PTHashIdxer;
using grape::PTHashIdxerBuilder;

namespace detail {
namespace perfect_hash {

/**
 * @brief Adapter over `Array::IteratorType` whose dereference yields a plain
 * `Value`.
 *
 * Dereferencing the wrapped iterator is expected to produce an object with a
 * `value()` member; `operator*` returns the result of that call converted to
 * `Value`. Increment and comparison are forwarded to the wrapped iterator.
 *
 * Copy and move operations are defaulted, so the wrapped iterator is
 * copy-/move-constructed directly and does not have to be
 * default-constructible.
 *
 * @tparam Value The type returned by `operator*`.
 * @tparam Array A type providing a nested `IteratorType`.
 */
template <typename Value, typename Array>
struct arrow_array_iterator {
 public:
  using value_type = Value;
  explicit arrow_array_iterator(const typename Array::IteratorType& iterator)
      : iterator_(iterator) {}

  arrow_array_iterator(const arrow_array_iterator& other) = default;
  arrow_array_iterator(arrow_array_iterator&& other) = default;
  arrow_array_iterator& operator=(const arrow_array_iterator& other) = default;
  arrow_array_iterator& operator=(arrow_array_iterator&& other) = default;
  ~arrow_array_iterator() = default;

  /// Return `value()` of the element the wrapped iterator points at.
  value_type operator*() const { return (*iterator_).value(); }

  /// Pre-increment: advance the wrapped iterator.
  arrow_array_iterator& operator++() {
    ++iterator_;
    return *this;
  }

  /// Post-increment: advance, returning a copy of the previous position.
  arrow_array_iterator operator++(int) {
    arrow_array_iterator tmp(*this);
    ++iterator_;
    return tmp;
  }

  bool operator==(const arrow_array_iterator& other) const {
    return iterator_ == other.iterator_;
  }

  bool operator!=(const arrow_array_iterator& other) const {
    return iterator_ != other.iterator_;
  }

 private:
  typename Array::IteratorType iterator_;
};

/**
 * @brief Scatter a value array into perfect-hash order (raw key array).
 *
 * For every `i` in `[0, n_elements)`, looks up the index of `keys[i]` in
 * `idxer_` and stores `begin_value[i]` at that index of `values`. The work is
 * distributed with `parallel_for`.
 *
 * The result of `get_index` is not checked: if a lookup fails, the index
 * stays 0 and `values[0]` is written.
 *
 * @param idxer_ The perfect-hash indexer used to locate each key.
 * @param keys Pointer to `n_elements` keys.
 * @param n_elements The number of keys/values to process.
 * @param begin_value Pointer to `n_elements` source values, parallel to `keys`.
 * @param values Destination array, indexed by the indexer's output.
 * @param concurrency The parallelism passed to `parallel_for`.
 * @return `Status::OK()` once all values are placed; fails the assertion
 *         (via `RETURN_ON_ASSERT`) when `K` is not an integral type.
 */
template <typename K, typename V>
Status build_values(
    PTHashIdxer<K, uint64_t>& idxer_, const K* keys, const size_t n_elements,
    const V* begin_value, V* values,
    const size_t concurrency = std::thread::hardware_concurrency()) {
  RETURN_ON_ASSERT(std::is_integral<K>::value, "K must be integral type.");
  auto place_value = [&](const size_t i) {
    uint64_t slot = 0;
    idxer_.get_index(keys[i], slot);
    values[slot] = begin_value[i];
  };
  parallel_for(static_cast<size_t>(0), n_elements, place_value, concurrency);
  return Status::OK();
}

/**
 * @brief Scatter a value array into perfect-hash order (arrow key array).
 *
 * For every `i` in `[0, keys->length())`, looks up the index of
 * `keys->GetView(i)` in `idxer_` and stores `begin_value[i]` at that index of
 * `values`. The work is distributed with `parallel_for`.
 *
 * The result of `get_index` is not checked: if a lookup fails, the index
 * stays 0 and `values[0]` is written.
 *
 * @param idxer_ The perfect-hash indexer used to locate each key.
 * @param keys The arrow array holding the keys.
 * @param begin_value Pointer to the source values, parallel to `keys`.
 * @param values Destination array, indexed by the indexer's output.
 * @param concurrency The parallelism passed to `parallel_for`.
 * @return Always `Status::OK()`.
 */
template <typename K, typename V>
Status build_values(
    PTHashIdxer<K, uint64_t>& idxer_,
    const std::shared_ptr<ArrowArrayType<K>>& keys, const V* begin_value,
    V* values, const size_t concurrency = std::thread::hardware_concurrency()) {
  auto place_value = [&](const size_t i) {
    uint64_t slot = 0;
    idxer_.get_index(keys->GetView(i), slot);
    values[slot] = begin_value[i];
  };
  parallel_for(static_cast<size_t>(0), static_cast<size_t>(keys->length()),
               place_value, concurrency);
  return Status::OK();
}

/**
 * @brief Fill the value array with consecutive values in perfect-hash order
 * (raw key array).
 *
 * For every `i` in `[0, n_elements)`, looks up the index of `keys[i]` in
 * `idxer_` and stores `begin_value + i` at that index of `values`. The work
 * is distributed with `parallel_for`.
 *
 * The result of `get_index` is not checked: if a lookup fails, the index
 * stays 0 and `values[0]` is written.
 *
 * @param idxer_ The perfect-hash indexer used to locate each key.
 * @param keys Pointer to `n_elements` keys.
 * @param n_elements The number of keys to process.
 * @param begin_value The value assigned to `keys[0]`; `keys[i]` receives
 *        `begin_value + i`.
 * @param values Destination array, indexed by the indexer's output.
 * @param concurrency The parallelism passed to `parallel_for`.
 * @return `Status::OK()` once all values are placed; fails the assertion
 *         (via `RETURN_ON_ASSERT`) when `K` is not an integral type.
 */
template <typename K, typename V>
Status build_values(
    PTHashIdxer<K, uint64_t>& idxer_, const K* keys, const size_t n_elements,
    const V begin_value, V* values,
    const size_t concurrency = std::thread::hardware_concurrency()) {
  RETURN_ON_ASSERT(std::is_integral<K>::value, "K must be integral type.");
  parallel_for(
      static_cast<size_t>(0), n_elements,
      [&](const size_t index) {
        uint64_t v_index_ = 0;
        // `keys` is a raw pointer here, so index it directly; `GetView` only
        // exists on the arrow-array overload.
        idxer_.get_index(keys[index], v_index_);
        values[v_index_] = begin_value + index;
      },
      concurrency);
  return Status::OK();
}

/**
 * @brief Fill the value array with consecutive values in perfect-hash order
 * (arrow key array).
 *
 * For every `i` in `[0, keys->length())`, looks up the index of
 * `keys->GetView(i)` in `idxer_` and stores `begin_value + i` at that index
 * of `values`. The work is distributed with `parallel_for`.
 *
 * The result of `get_index` is not checked: if a lookup fails, the index
 * stays 0 and `values[0]` is written.
 *
 * @param idxer_ The perfect-hash indexer used to locate each key.
 * @param keys The arrow array holding the keys.
 * @param begin_value The value assigned to the first key; the key at
 *        position `i` receives `begin_value + i`.
 * @param values Destination array, indexed by the indexer's output.
 * @param concurrency The parallelism passed to `parallel_for`.
 * @return Always `Status::OK()`.
 */
template <typename K, typename V>
Status build_values(
    PTHashIdxer<K, uint64_t>& idxer_,
    const std::shared_ptr<ArrowArrayType<K>>& keys, const V begin_value,
    V* values, const size_t concurrency = std::thread::hardware_concurrency()) {
  auto place_value = [&](const size_t i) {
    uint64_t slot = 0;
    idxer_.get_index(keys->GetView(i), slot);
    values[slot] = begin_value + i;
  };
  parallel_for(static_cast<size_t>(0), static_cast<size_t>(keys->length()),
               place_value, concurrency);
  return Status::OK();
}

}  // namespace perfect_hash
}  // namespace detail

/**
 * @brief A key-to-value map backed by a perfect-hash indexer.
 *
 * Lookups ask `PTHashIdxer` for the index of a key and read the value stored
 * at that index of the value blob. The class only exposes lookup operations.
 *
 * @tparam K The type for the key.
 * @tparam V The type for the value.
 */
template <typename K, typename V>
class [[vineyard]] PerfectHashmap : public Registered<PerfectHashmap<K, V>> {
 public:
  using value_type = V;

  /**
   * @brief Look up a key.
   *
   * @return A pointer to the stored value, or `nullptr` when the indexer
   *         reports that the key is unknown.
   */
  const V* find(const K& key) const {
    uint64_t slot = 0;
    if (!this->idxer_.get_index(key, slot)) {
      return nullptr;
    }
    return &ph_values_ptr_[slot];
  }

  /**
   * @brief Return the number of occurrences of the key (0 or 1).
   */
  size_t count(const K& key) const {
    return this->find(key) != nullptr ? 1 : 0;
  }

  /**
   * @brief Return the number of elements stored in the map.
   */
  size_t size() const { return num_elements_; }

  /**
   * @brief Return the bucket count, which equals the number of elements.
   */
  size_t bucket_count() const { return num_elements_; }

  /**
   * @brief Return the load factor, which is always 1.
   */
  float load_factor() const { return 1.0f; }

  /**
   * @brief Check whether the map is empty.
   */
  bool empty() const { return num_elements_ == 0; }

  /**
   * @brief Get a copy of the value by key.
   *
   * @throws std::out_of_range if the key is not in the map.
   */
  const V at(const K& key) const {
    const V* found = this->find(key);
    if (found != nullptr) {
      return *found;
    }
    throw std::out_of_range("Argument passed to at() was not in the map.");
  }

  /**
   * @brief Finish initialization after the shared members are populated.
   *
   * Caches a typed pointer into the value blob and initializes the indexer
   * from the bytes of the state blob.
   */
  void PostConstruct(const ObjectMeta& meta) override {
    ph_values_ptr_ = reinterpret_cast<const V*>(ph_values_->data());
    // Init() takes a mutable pointer, so constness is cast away here.
    const void* state = static_cast<const void*>(ph_->data());
    idxer_.Init(const_cast<void*>(state), ph_->size());
  }

 private:
  [[shared]] size_t num_elements_;
  [[shared]] std::shared_ptr<Object> ph_keys_;
  [[shared]] std::shared_ptr<Blob> ph_values_;
  [[shared]] std::shared_ptr<Blob> ph_;  // state

  // Typed view of ph_values_, set in PostConstruct().
  const V* ph_values_ptr_ = nullptr;
  // Mutable so the const lookup methods can call into the indexer.
  mutable PTHashIdxer<K, uint64_t> idxer_;

  friend class Client;
  friend class PerfectHashmapBaseBuilder<K, V>;
};

#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

}  // namespace vineyard

#endif  // MODULES_BASIC_DS_HASHMAP_VINEYARD_MOD_

// vim: syntax=cpp
